2c783f
@@ -69,6 +69,7 @@
   private final Configuration conf;
   private final JobConf jobConf;
   private final MRInputUserPayloadProto userPayloadProto;
+  private final MapWork work;
   private final SplitGrouper splitGrouper = new SplitGrouper();
 
 
@@ -86,7 +87,7 @@
public HiveSplitGenerator(InputInitializerContext initializerContext) throws IOE
     // Read all credentials into the credentials instance stored in JobConf.
     ShimLoader.getHadoopShims().getMergedCredentials(jobConf);
 
-    MapWork work = Utilities.getMapWork(jobConf);
+    this.work = Utilities.getMapWork(jobConf);
 
     // Events can start coming in the moment the InputInitializer is created. The pruner
     // must be setup and initialized here so that it sets up it's structures to start accepting events.
@@ -98,58 +99,64 @@
public HiveSplitGenerator(InputInitializerContext initializerContext) throws IOE
 
   @Override
   public List<Event> initialize() throws Exception {
-    boolean sendSerializedEvents =
-        conf.getBoolean("mapreduce.tez.input.initializer.serialize.event.payload", true);
-
-    // perform dynamic partition pruning
-    pruner.prune();
-
-    InputSplitInfoMem inputSplitInfo = null;
-    String realInputFormatName = conf.get("mapred.input.format.class");
-    boolean groupingEnabled = userPayloadProto.getGroupingEnabled();
-    if (groupingEnabled) {
-      // Need to instantiate the realInputFormat
-      InputFormat<?, ?> inputFormat =
-          (InputFormat<?, ?>) ReflectionUtils.newInstance(JavaUtils.loadClass(realInputFormatName),
-              jobConf);
-
-      int totalResource = getContext().getTotalAvailableResource().getMemory();
-      int taskResource = getContext().getVertexTaskResource().getMemory();
-      int availableSlots = totalResource / taskResource;
-
-      // Create the un-grouped splits
-      float waves =
-          conf.getFloat(TezMapReduceSplitsGrouper.TEZ_GROUPING_SPLIT_WAVES,
-              TezMapReduceSplitsGrouper.TEZ_GROUPING_SPLIT_WAVES_DEFAULT);
-
-      InputSplit[] splits = inputFormat.getSplits(jobConf, (int) (availableSlots * waves));
-      LOG.info("Number of input splits: " + splits.length + ". " + availableSlots
-          + " available slots, " + waves + " waves. Input format is: " + realInputFormatName);
-
-      Multimap<Integer, InputSplit> groupedSplits =
-          splitGrouper.generateGroupedSplits(jobConf, conf, splits, waves, availableSlots);
-      // And finally return them in a flat array
-      InputSplit[] flatSplits = groupedSplits.values().toArray(new InputSplit[0]);
-      LOG.info("Number of grouped splits: " + flatSplits.length);
-
-      List<TaskLocationHint> locationHints = splitGrouper.createTaskLocationHints(flatSplits);
+    // Set up the map work for this thread. Pruning modifies the work instance (it may remove
+    // partitions), so the same work instance must be used when generating splits.
+    Utilities.setMapWork(jobConf, work);
+    try {
+      boolean sendSerializedEvents =
+          conf.getBoolean("mapreduce.tez.input.initializer.serialize.event.payload", true);
+
+      // perform dynamic partition pruning
+      pruner.prune();
+
+      InputSplitInfoMem inputSplitInfo = null;
+      String realInputFormatName = conf.get("mapred.input.format.class");
+      boolean groupingEnabled = userPayloadProto.getGroupingEnabled();
+      if (groupingEnabled) {
+        // Need to instantiate the realInputFormat
+        InputFormat<?, ?> inputFormat =
+            (InputFormat<?, ?>) ReflectionUtils
+                .newInstance(JavaUtils.loadClass(realInputFormatName),
+                    jobConf);
+
+        int totalResource = getContext().getTotalAvailableResource().getMemory();
+        int taskResource = getContext().getVertexTaskResource().getMemory();
+        int availableSlots = totalResource / taskResource;
+
+        // Create the un-grouped splits
+        float waves =
+            conf.getFloat(TezMapReduceSplitsGrouper.TEZ_GROUPING_SPLIT_WAVES,
+                TezMapReduceSplitsGrouper.TEZ_GROUPING_SPLIT_WAVES_DEFAULT);
+
+        InputSplit[] splits = inputFormat.getSplits(jobConf, (int) (availableSlots * waves));
+        LOG.info("Number of input splits: " + splits.length + ". " + availableSlots
+            + " available slots, " + waves + " waves. Input format is: " + realInputFormatName);
+
+        Multimap<Integer, InputSplit> groupedSplits =
+            splitGrouper.generateGroupedSplits(jobConf, conf, splits, waves, availableSlots);
+        // And finally return them in a flat array
+        InputSplit[] flatSplits = groupedSplits.values().toArray(new InputSplit[0]);
+        LOG.info("Number of grouped splits: " + flatSplits.length);
+
+        List<TaskLocationHint> locationHints = splitGrouper.createTaskLocationHints(flatSplits);
+
+        inputSplitInfo =
+            new InputSplitInfoMem(flatSplits, locationHints, flatSplits.length, null, jobConf);
+      } else {
+        // no need for grouping and the target #of tasks.
+        // This code path should never be triggered at the moment. If grouping is disabled,
+        // DAGUtils uses MRInputAMSplitGenerator.
+        // If this is used in the future - make sure to disable grouping in the payload, if it isn't already disabled
+        throw new RuntimeException(
+            "HiveInputFormat does not support non-grouped splits, InputFormatName is: "
+                + realInputFormatName);
+        // inputSplitInfo = MRInputHelpers.generateInputSplitsToMem(jobConf, false, 0);
+      }
 
+      return createEventList(sendSerializedEvents, inputSplitInfo);
+    } finally {
       Utilities.clearWork(jobConf);
-
-      inputSplitInfo =
-          new InputSplitInfoMem(flatSplits, locationHints, flatSplits.length, null, jobConf);
-    } else {
-      // no need for grouping and the target #of tasks.
-      // This code path should never be triggered at the moment. If grouping is disabled,
-      // DAGUtils uses MRInputAMSplitGenerator.
-      // If this is used in the future - make sure to disable grouping in the payload, if it isn't already disabled
-      throw new RuntimeException(
-          "HiveInputFormat does not support non-grouped splits, InputFormatName is: "
-              + realInputFormatName);
-      // inputSplitInfo = MRInputHelpers.generateInputSplitsToMem(jobConf, false, 0);
     }
-
-    return createEventList(sendSerializedEvents, inputSplitInfo);
   }
 
 
